

#' Mean squared error of the model-implied outcome prediction
#'
#' Builds a prediction of the outcome from the fitted outcome, mediator and
#' treatment coefficients, marginalising over the variables that the chosen
#' estimator does not condition on, and returns the mean squared difference
#' between that prediction and `Y_test`.
#'
#' * `"G-formula"`: sums out M, giving E[Y | A, C].
#' * `"IPW"` / `"AIPW"`: sums out A and M, giving E[Y | C].
#' * `"Mixed"`: sums out A given the observed M, giving E[Y | M, C].
#'
#' Y is modelled linearly (`design %*% beta_y`); M and A are modelled with a
#' logistic link.
#'
#' @param dat Data with columns `A` and `M` plus every covariate named in the
#'   coefficient vectors. Column 1 is used as the intercept column of every
#'   model (assumes it is a column of ones -- TODO confirm against caller).
#' @param Y_test Observed outcomes, one per row of `dat`.
#' @param beta List with named coefficient vectors `beta_y`, `beta_m` and
#'   `beta_a` (the latter is not used by `"G-formula"`). The first element of
#'   each vector is the intercept; the remaining names are matched against
#'   `colnames(dat)`.
#' @param opt List whose `estimator` element is one of `"G-formula"`, `"IPW"`,
#'   `"AIPW"` or `"Mixed"`. Other elements (e.g. `reparam`) are not used here.
#' @return A single numeric value: `mean((Y_test - Y_hat)^2)`.
compute_mse <- function(dat, Y_test, beta, opt) {

  estimator <- opt$estimator
  known_estimators <- c("G-formula", "IPW", "AIPW", "Mixed")
  # Fail early with a readable message instead of the "object 'Y_hat' not
  # found" error that an unrecognised estimator would otherwise produce.
  if (!is.character(estimator) || length(estimator) != 1L ||
      !estimator %in% known_estimators) {
    stop("opt$estimator must be one of: ",
         paste(known_estimators, collapse = ", "), call. = FALSE)
  }

  beta_y <- beta$beta_y
  beta_m <- beta$beta_m
  beta_a <- beta$beta_a

  # Column positions of a model's regressors: column 1 (intercept) followed by
  # the columns named after the non-intercept coefficients.
  # Positions are looked up in `dat` and then applied to the designs returned
  # by process_data(), which presumably keep the column layout of `dat`
  # (process_data is defined elsewhere -- verify).
  design_index <- function(coefs) {
    covariates <- names(coefs)[-1]
    idx <- match(covariates, colnames(dat))
    if (anyNA(idx)) {
      stop("coefficient(s) without a matching column in dat: ",
           paste(covariates[is.na(idx)], collapse = ", "), call. = FALSE)
    }
    c(1, idx)
  }

  # Linear predictor X %*% beta as an n x 1 matrix. as.matrix() makes this
  # work whether the design is a data.frame or a matrix; drop = FALSE keeps
  # an intercept-only model two-dimensional.
  linear_predictor <- function(design, coefs) {
    as.matrix(design[, design_index(coefs), drop = FALSE]) %*% coefs
  }

  if (estimator == "G-formula") {
    # +++++++++++++++++++++++++++++++++
    # G-formula: sum out {M}
    # +++++++++++++++++++++++++++++++++

    # p*(M = 1 | A, C) at the observed A
    p_m1 <- stats::plogis(linear_predictor(dat, beta_m))
    p_m0 <- 1 - p_m1

    # E*[Y | M = m, A, C] at the observed A
    y_m1 <- linear_predictor(process_data(dat, a = dat$A, m = 1), beta_y)
    y_m0 <- linear_predictor(process_data(dat, a = dat$A, m = 0), beta_y)

    # E*[Y | A, C] = \sum_M E*[Y | A, C, M] p*(M | A, C)
    Y_hat <- y_m1 * p_m1 + y_m0 * p_m0

  } else if (estimator %in% c("IPW", "AIPW")) {
    # +++++++++++++++++++++++++++++++++
    # IPW and AIPW: sum out {A, M}
    # +++++++++++++++++++++++++++++++++

    dat_a0m0 <- process_data(dat, a = 0, m = 0)
    dat_a0m1 <- process_data(dat, a = 0, m = 1)
    dat_a1m0 <- process_data(dat, a = 1, m = 0)
    dat_a1m1 <- process_data(dat, a = 1, m = 1)

    # p*(A | C)
    p_a1 <- stats::plogis(linear_predictor(dat, beta_a))
    p_a0 <- 1 - p_a1

    # p*(M = 1 | A = a, C): must be evaluated at each counterfactual a being
    # summed over, not at the observed A (same construction as "Mixed").
    p_m1_a0 <- stats::plogis(linear_predictor(dat_a0m1, beta_m))
    p_m1_a1 <- stats::plogis(linear_predictor(dat_a1m1, beta_m))

    # E*[Y | M = m, A = a, C]
    y_a0m0 <- linear_predictor(dat_a0m0, beta_y)
    y_a0m1 <- linear_predictor(dat_a0m1, beta_y)
    y_a1m0 <- linear_predictor(dat_a1m0, beta_y)
    y_a1m1 <- linear_predictor(dat_a1m1, beta_y)

    # E*[Y | C] = \sum_{A,M} E*[Y | A, C, M] p*(M | A, C) p*(A | C)
    Y_hat <- p_a0 * (y_a0m0 * (1 - p_m1_a0) + y_a0m1 * p_m1_a0) +
      p_a1 * (y_a1m0 * (1 - p_m1_a1) + y_a1m1 * p_m1_a1)

  } else {
    # +++++++++++++++++++++++++++++++++
    # Mixed: sum out {A}
    # +++++++++++++++++++++++++++++++++

    # p*(A | C)
    p_a1 <- stats::plogis(linear_predictor(dat, beta_a))
    p_a0 <- 1 - p_a1

    # p*(M = 1 | A = a, C)
    p_m1_a0 <- stats::plogis(
      linear_predictor(process_data(dat, a = 0, m = 1), beta_m)
    )
    p_m1_a1 <- stats::plogis(
      linear_predictor(process_data(dat, a = 1, m = 1), beta_m)
    )

    # p*(M = observed m | A = a, C): flip to 1 - p on rows where M == 0
    m_is_zero <- dat$M == 0
    p_m_a0 <- p_m1_a0
    p_m_a0[m_is_zero] <- 1 - p_m1_a0[m_is_zero]
    p_m_a1 <- p_m1_a1
    p_m_a1[m_is_zero] <- 1 - p_m1_a1[m_is_zero]

    # E*[Y | M = observed m, A = a, C]
    y_a0m <- linear_predictor(process_data(dat, a = 0, m = dat$M), beta_y)
    y_a1m <- linear_predictor(process_data(dat, a = 1, m = dat$M), beta_y)

    # E*[Y | M, C] = \sum_A E*[Y | A, C, M] p(M | A, C) p(A | C) /
    #                \sum_A p(M | A, C) p(A | C)
    Y_hat <- (y_a0m * p_m_a0 * p_a0 + y_a1m * p_m_a1 * p_a1) /
      (p_m_a0 * p_a0 + p_m_a1 * p_a1)
  }

  # Guard against silent recycling of a too-short Y_test.
  if (NROW(Y_test) != NROW(Y_hat)) {
    stop("Y_test must have one value per row of dat", call. = FALSE)
  }

  mean((Y_test - Y_hat)^2)
}
